import os

import plotly.express as px
import plotly.graph_objects as go
from IPython.display import HTML
# 1) Build the DataFrame
df_plot = features.copy()
df_plot['Cluster'] = eda.loc[features.index, 'Cluster']
# 2) Compute centroids in original units
centroids = kmeans.cluster_centers_
centroids_x = centroids[:, 0] * X.std(axis=0)[0] + X.mean(axis=0)[0]
centroids_y = centroids[:, 1] * X.std(axis=0)[1] + X.mean(axis=0)[1]
# 3) Create an interactive Plotly Figure
fig = px.scatter(
df_plot,
x='SALARY',
y='MAX_YEARS_EXPERIENCE',
color='Cluster',
title="KMeans Clustering by Salary and Max Years Experience",
labels={
'SALARY': 'Salary',
'MAX_YEARS_EXPERIENCE': 'Max Years Experience',
'Cluster': 'Cluster'
},
width=800,
height=500,
)
# add centroids
fig.add_trace(
go.Scatter(
x=centroids_x,
y=centroids_y,
mode='markers',
marker=dict(symbol='x', size=18, color='black',
line=dict(width=2, color='white')),
name='Centroids'
)
)
fig.update_layout(
autosize=True,
height=800,
margin=dict(l=20, r=20, t=50, b=20)
)
fig.write_html(
"figures/analytics_plot1.html",
include_plotlyjs=True,
full_html=False
)